/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License. 
*/

#pragma once
#include "./custom_type.h"
#include "./data_transfer.h"
#include "./kernel_const.h"
#include "./kernel_utils.h"
#include "kernel_operator.h"



/**
 * @brief Batched logical matrix padding.
 *
 * For every batch entry whose d_isPadding flag is set, the valid region of the source matrix is
 * copied into the corresponding destination matrix, staged through a double-buffered on-chip
 * buffer of UB_BYTES bytes. In this function a "row" is one contiguous line of storage: a matrix
 * row (length N) when layout == RowMajor, otherwise a matrix column (length M).
 *
 * The work of all processed batches is cut into tasks that are numbered consecutively across the
 * batch; task t is executed by the core for which t % aivSum == aivIdx.
 *
 * @tparam L1_M0 tile size along the M dimension of the full matrix
 * @tparam L1_N0 tile size along the N dimension of the full matrix
 * @param[in] layout: row-major / column-major storage type of the matrices
 * @param[in] zeroPaddingM: size of the M dimension of the full matrix (source row stride when column-major)
 * @param[in] zeroPaddingN: size of the N dimension of the full matrix (source row stride when row-major)
 * @param[in] batchCount: number of matrices in the batch
 * @param[in] d_validM: per-batch valid length along M
 * @param[in] d_validN: per-batch valid length along N
 * @param[in] d_matrixPointer: per-batch base address of the source matrix
 * @param[in] d_isPadding: per-batch flag; a batch is processed only when its entry is non-zero.
 *                         Flagged batches whose valid M or N length is 0 are skipped as well.
 * @param[out] d_matrixPointerPadding: per-batch base address of the padded destination matrix
 * @param[in] paddingDir: padding mode.
 *            0 - keep the row layout; the destination row stride is the source stride rounded up
 *                to a multiple of L1_N0 (row-major) or L1_M0 (column-major);
 *            1 - tiled layout Zz (row-major) / Nn (column-major);
 *            2 - tiled layout Nz (row-major) / Zn (column-major).
 *            Other values are not handled by the code below.
 *            NOTE(review): the previous comment named L1_M0 / L1_N0 the other way round for mode 0;
 *            the description above follows the code (RoundUp of the stride to blockRowSize).
*/
template<
    uint32_t L1_M0,
    uint32_t L1_N0
>
[aicore] inline __attribute__((always_inline)) void BatchMatrixPadding(
    layoutType layout,
    uint32_t zeroPaddingM,
    uint32_t zeroPaddingN,
    uint32_t batchCount,
    __gm__ uint32_t*  d_validM,
    __gm__ uint32_t*  d_validN,
    __gm__ half**  d_matrixPointer,
    __gm__ uint8_t *d_isPadding,
    __gm__ half** d_matrixPointerPadding,
    uint8_t paddingDir
){

    // Obtain a UB_BYTES-sized tensor through a temporary pipe; the pipe is destroyed right away
    // and only the tensor handle is used afterwards.
    AscendC::TBuf<AscendC::TPosition::VECIN> ub_buf;
    AscendC::TPipe ub_pipe;
    ub_pipe.InitBuffer(ub_buf, UB_BYTES);
    AscendC::LocalTensor<uint8_t> ub_tensor = ub_buf.Get<uint8_t>();
    ub_pipe.Destroy();

    // Split the buffer into two equally sized ping-pong halves.
    static constexpr uint32_t ub_paddingPingpongNum = 2;
    AscendC::LocalTensor<half> ub_paddingBuf[ub_paddingPingpongNum];

    uint64_t ub_paddingPingpongBytes = UB_BYTES / ub_paddingPingpongNum; // bytes per half
    uint64_t ub_paddingPingpongSize = ub_paddingPingpongBytes / sizeof(half); // half elements per half
    #pragma unroll
    for(uint32_t i = 0; i < ub_paddingPingpongNum; i++){
        // ub_tensor is a uint8_t tensor, so the index is a byte offset.
        ub_paddingBuf[i] = ub_tensor[ i * ub_paddingPingpongBytes ].template ReinterpretCast<half>();
    }
    // Set one MTE3->MTE2 event per buffer up front; every task waits on the event of its buffer
    // before loading into it and sets it again after storing, and the events are drained at the end.
    #pragma unroll
    for(uint32_t i = 0; i < ub_paddingPingpongNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(i));
    }

    AscendC::GlobalTensor<half> gm_matrix;
    AscendC::GlobalTensor<half> gm_matrixPadding;

    // Tile shape expressed in storage rows: blockRowSize elements per row, blockRowNum rows.
    uint32_t blockRowSize;
    uint32_t blockRowNum;
    if(layout == RowMajor){
        blockRowSize = L1_N0;
        blockRowNum = L1_M0;
    }else{
        blockRowSize = L1_M0;
        blockRowNum = L1_N0;
    }

    uint32_t blockSize = blockRowNum * blockRowSize;          // elements per tile
    uint32_t datablockSize = DATABLOCK_BYTES / sizeof(half);  // half elements per DATABLOCK_BYTES

    // State of the batch entry currently being processed (refreshed whenever a new entry starts).
    uint32_t gm_stride = 0;               // source row stride in elements
    uint32_t gm_paddingStride = 0;        // destination stride passed to Ub2Gm
    uint32_t matrixRowSize = 0;           // valid elements per row
    uint32_t matrixRowSizeAlign32B = 0;   // matrixRowSize rounded up to a multiple of datablockSize
    uint32_t matrixRowNum = 0;            // number of valid rows
    uint32_t matrixRowSizeBlockLoops = 0; // tiles along a row
    uint32_t matrixRowNumBlockLoops = 0;  // tiles across the rows
    uint8_t isMatrixRowExceedBuf = 0;     // non-zero: rows are split into several tasks
    uint32_t matrixRowSizeTaskLoops = 0;  // tasks per row (split-row path only)
    uint32_t matrixRowNumTaskLoops = 0;   // tasks across the rows
    uint32_t tasksNumPerRow = 0;
    uint32_t rowNumPerBuf = 0;            // rows per task (multi-row path only)

    uint32_t aivIdx = AscendC::GetBlockIdx();
    uint32_t aivSum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();


    uint32_t curAicoreTask = 0; // tasks executed by this core so far; selects the ping-pong buffer
    uint32_t curTasksSum = 0;   // global index of the first task of the current batch entry
    // Starts at UINT32_MAX so that the first batchNum++ below yields batch 0.
    uint32_t batchNum = static_cast<uint32_t>(-1);
    uint32_t loopSum = 0;       // total number of tasks of all batch entries set up so far

    // loopIdx is the global task index. Reaching loopSum means every task of the entries set up
    // so far has been visited, so the next batch entry is set up (which extends loopSum).
    for(uint32_t loopIdx = 0; loopIdx < loopSum+1; loopIdx++){

        if( loopIdx == loopSum ){
            batchNum++;
            if(batchNum < batchCount){
                // An entry is processed only when it is flagged AND has a non-empty valid region.
                // An empty entry would contribute no task: control would fall into the task body
                // below with unusable state and the loop would end before later entries are seen;
                // a zero row length would additionally make the rowNumPerBuf divisor zero.
                if(d_isPadding[batchNum] && d_validM[batchNum] != 0 && d_validN[batchNum] != 0){
                    curTasksSum = loopSum;
                    gm_matrix.SetGlobalBuffer( (__gm__ half*)d_matrixPointer[batchNum] );
                    gm_matrixPadding.SetGlobalBuffer( (__gm__ half*)d_matrixPointerPadding[batchNum] );
                    gm_stride = (layout == RowMajor ? zeroPaddingN : zeroPaddingM);
                    matrixRowSize = (layout == RowMajor ? d_validN[batchNum] : d_validM[batchNum] );
                    matrixRowSizeAlign32B = RoundUp<uint32_t>(matrixRowSize, datablockSize);
                    matrixRowNum = (layout == RowMajor ? d_validM[batchNum] : d_validN[batchNum] );
                    matrixRowSizeBlockLoops = CeilDiv<uint32_t>(matrixRowSize, blockRowSize);
                    matrixRowNumBlockLoops = CeilDiv<uint32_t>(matrixRowNum, blockRowNum);
                    if(paddingDir == 0){
                        gm_paddingStride = RoundUp<uint32_t>(gm_stride, blockRowSize);
                    }else if(paddingDir == 1){
                        gm_paddingStride = blockSize;
                    }else if(paddingDir == 2){
                        gm_paddingStride = matrixRowNumBlockLoops * blockSize;
                    }
                    // The multi-row path needs room for at least one aligned row plus a reserve of
                    // blockRowSize elements in a ping-pong half. Comparing only matrixRowSize with
                    // the buffer size would let rowNumPerBuf become 0 for rows slightly smaller
                    // than the buffer, and CeilDiv would then be called with a zero divisor.
                    // Such rows take the split-row path instead (one task per row in that case).
                    isMatrixRowExceedBuf =
                        (static_cast<uint64_t>(matrixRowSizeAlign32B) + blockRowSize > ub_paddingPingpongSize);
                    if(isMatrixRowExceedBuf){
                        // Split-row path: each row is cut into chunks of at most ub_paddingPingpongSize elements.
                        tasksNumPerRow = CeilDiv<uint32_t>(matrixRowSize, ub_paddingPingpongSize);
                        matrixRowSizeTaskLoops = tasksNumPerRow;
                        matrixRowNumTaskLoops = matrixRowNum;
                        loopSum += matrixRowSizeTaskLoops * matrixRowNumTaskLoops;
                    }else{
                        // Multi-row path: each task handles rowNumPerBuf (>= 1) whole rows.
                        rowNumPerBuf = (ub_paddingPingpongSize - blockRowSize ) / matrixRowSizeAlign32B;
                        matrixRowNumTaskLoops = CeilDiv<uint32_t>(matrixRowNum, rowNumPerBuf);
                        loopSum += matrixRowNumTaskLoops;
                    }
                }else{
                    // Nothing to do for this entry: undo the upcoming loopIdx++ so the same task
                    // index is used to set up the next entry (unsigned wrap-around at 0 is intended).
                    loopIdx--;
                    continue;
                }
            }else{
                // All batch entries consumed; the loop condition fails after the increment.
                continue;
            }
        }

        // Round-robin distribution of tasks over the cores.
        if(loopIdx % aivSum !=aivIdx ){
            continue;
        }

        uint32_t taskIdx = loopIdx - curTasksSum;                      // task index inside the current entry
        uint8_t curTaskBufIdx = curAicoreTask % ub_paddingPingpongNum; // ping-pong buffer / event id

        if(isMatrixRowExceedBuf){

            // Split-row path: this task covers one chunk of one row.
            uint32_t matrixRowSizeTaskIdx = taskIdx % matrixRowSizeTaskLoops; // chunk index within the row
            uint32_t matrixRowNumTaskIdx = taskIdx / matrixRowSizeTaskLoops;  // row index
            uint32_t matrixRowSizeTaskAddr = matrixRowSizeTaskIdx * ub_paddingPingpongSize; // element offset in the row
            uint32_t matrixRowNumTaskAddr = matrixRowNumTaskIdx;
            uint32_t matrixTaskAddr = matrixRowNumTaskAddr * gm_stride + matrixRowSizeTaskAddr; // source offset
            // The last chunk of a row carries the remainder, all others a full buffer.
            uint32_t gm_matrixRowSizeActual = (matrixRowSizeTaskIdx == matrixRowSizeTaskLoops-1 ?
                                                matrixRowSize - matrixRowSizeTaskAddr
                                                : ub_paddingPingpongSize);
            uint32_t gm_matrixRowNumActual = 1;

            // Tile coordinates of the chunk start.
            uint32_t matrixRowSizeBlockIdx = matrixRowSizeTaskAddr / blockRowSize;
            uint32_t matrixRowNumBlockIdx = matrixRowNumTaskAddr / blockRowNum;
            uint32_t matrixRowNumBlockInnerIdx = matrixRowNumTaskAddr % blockRowNum;
            uint32_t matrixTaskPaddingAddr;  // destination offset
            uint32_t copyOutSegmentNum;      // number of segments written
            uint32_t copyOutSegmentSize;     // elements per segment
            if(paddingDir == 0){
                // Same row layout, only the row stride differs: write the chunk as one segment.
                matrixTaskPaddingAddr = matrixRowNumTaskAddr * gm_paddingStride + matrixRowSizeTaskAddr;
                copyOutSegmentNum = gm_matrixRowNumActual;
                copyOutSegmentSize = gm_matrixRowSizeActual;
            }else if(paddingDir == 1){ // Zz/Nn
                // Tiles of one tile-row are consecutive; one blockRowSize segment per tile.
                matrixTaskPaddingAddr = matrixRowNumBlockIdx * matrixRowSizeBlockLoops * blockSize
                                      + matrixRowSizeBlockIdx * blockSize
                                      + matrixRowNumBlockInnerIdx * blockRowSize;
                copyOutSegmentNum = CeilDiv<uint32_t>(gm_matrixRowSizeActual, blockRowSize);
                copyOutSegmentSize = blockRowSize;
            }else if(paddingDir == 2){ // Nz/Zn
                // Tiles of one tile-column are consecutive; one blockRowSize segment per tile.
                matrixTaskPaddingAddr = matrixRowSizeBlockIdx * matrixRowNumBlockLoops * blockSize
                                      + matrixRowNumBlockIdx * blockSize
                                      + matrixRowNumBlockInnerIdx * blockRowSize;
                copyOutSegmentNum = CeilDiv<uint32_t>(gm_matrixRowSizeActual, blockRowSize);
                copyOutSegmentSize = blockRowSize;
            }

            // Wait until the previous store out of this buffer has been issued-and-finished
            // before loading new data into it.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(curTaskBufIdx));

            // Gm2Ub / Ub2Gm are project helpers defined outside this view; their argument
            // semantics are taken as-is from the original call sites.
            Gm2Ub<half>(
                ub_paddingBuf[curTaskBufIdx],
                gm_matrix[matrixTaskAddr],
                gm_matrixRowNumActual,
                gm_matrixRowSizeActual,
                0,
                gm_stride,
                0,
                0
            );

            // The store must not start before the load has completed.
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(curTaskBufIdx));
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(curTaskBufIdx));

            Ub2Gm<half>(
                gm_matrixPadding[matrixTaskPaddingAddr],
                ub_paddingBuf[curTaskBufIdx],
                copyOutSegmentNum,
                copyOutSegmentSize,
                gm_paddingStride,
                0
            );


            // Release the buffer for the next load.
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(curTaskBufIdx));

        }else{

            // Multi-row path: this task loads up to rowNumPerBuf whole rows at once and then
            // stores them row by row.
            uint32_t matrixRowNumTaskIdx = taskIdx;
            uint32_t matrixRowNumTaskAddr = matrixRowNumTaskIdx * rowNumPerBuf; // first row of the task
            uint32_t matrixTaskAddr = matrixRowNumTaskAddr * gm_stride;         // source offset
            uint32_t gm_matrixRowSizeActual = matrixRowSize;
            uint32_t gm_matrixRowSizeActualAlign32B = matrixRowSizeAlign32B;
            // The last task of an entry carries the remaining rows.
            uint32_t gm_matrixRowNumActual = (matrixRowNumTaskIdx == matrixRowNumTaskLoops-1 ?
                                                matrixRowNum - matrixRowNumTaskAddr
                                                : rowNumPerBuf );
            uint32_t matrixRowNumTaskRowAddr;   // absolute row index of the row being stored
            uint32_t matrixRowSizeBlockIdx;
            uint32_t matrixRowNumBlockIdx;

            uint32_t matrixRowNumBlockInnerIdx;
            uint32_t matrixTaskPaddingAddr;     // destination offset
            uint32_t copyOutSegmentNum;         // number of segments written
            uint32_t copyOutSegmentSize;        // elements per segment

            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(curTaskBufIdx));
            // The last argument is the difference between the aligned and the valid row length;
            // rows are addressed in the buffer with a pitch of gm_matrixRowSizeActualAlign32B below.
            Gm2Ub(
                ub_paddingBuf[curTaskBufIdx],
                gm_matrix[matrixTaskAddr],
                gm_matrixRowNumActual,
                gm_matrixRowSizeActual,
                0,
                gm_stride,
                0,
                gm_matrixRowSizeActualAlign32B - gm_matrixRowSizeActual
            );

            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(curTaskBufIdx));
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(curTaskBufIdx));

            for(uint32_t rowIdx = 0; rowIdx < gm_matrixRowNumActual; rowIdx++){
                matrixRowNumTaskRowAddr = matrixRowNumTaskAddr + rowIdx;
                matrixRowSizeBlockIdx = 0; // every row starts at the first tile along the row
                matrixRowNumBlockIdx = matrixRowNumTaskRowAddr / blockRowNum;
                matrixRowNumBlockInnerIdx = matrixRowNumTaskRowAddr % blockRowNum;
                if(paddingDir == 0){
                    matrixTaskPaddingAddr = matrixRowNumTaskRowAddr * gm_paddingStride;
                    copyOutSegmentNum = 1;
                    copyOutSegmentSize = gm_matrixRowSizeActual;
                }else if(paddingDir == 1){ // Zz/Nn
                    matrixTaskPaddingAddr = matrixRowNumBlockIdx * matrixRowSizeBlockLoops * blockSize
                                          + matrixRowSizeBlockIdx * blockSize
                                          + matrixRowNumBlockInnerIdx * blockRowSize;
                    copyOutSegmentNum = CeilDiv<uint32_t>(gm_matrixRowSizeActual, blockRowSize);
                    copyOutSegmentSize = blockRowSize;
                }else if(paddingDir == 2){ // Nz/Zn
                    matrixTaskPaddingAddr = matrixRowSizeBlockIdx * matrixRowNumBlockLoops * blockSize
                                          + matrixRowNumBlockIdx * blockSize
                                          + matrixRowNumBlockInnerIdx * blockRowSize;
                    copyOutSegmentNum = CeilDiv<uint32_t>(gm_matrixRowSizeActual, blockRowSize);
                    copyOutSegmentSize = blockRowSize;
                }

                Ub2Gm<half>(
                    gm_matrixPadding[matrixTaskPaddingAddr],
                    ub_paddingBuf[curTaskBufIdx][rowIdx * gm_matrixRowSizeActualAlign32B],
                    copyOutSegmentNum,
                    copyOutSegmentSize,
                    gm_paddingStride,
                    0
                );


            }

            // Release the buffer for the next load.
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(curTaskBufIdx));

        }

        curAicoreTask++;

    }

    // Consume the outstanding MTE3->MTE2 event of each buffer so that no event is left set.
    #pragma unroll
    for(uint32_t i = 0; i < ub_paddingPingpongNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(i));
    }
}

